{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "bxqOyLRaUkYW"
   },
   "source": [
    "# ObJAX CIFAR10 example\n",
    "\n",
    "This example is based on [cifar10_simple.py](https://github.com/google/objax/blob/master/examples/classify/img/cifar10_simple.py) with a few minor changes:\n",
    "\n",
    "* it demonstrates how to do weight decay,\n",
    "* it uses Momentum optimizer with learning rate schedule,\n",
    "* it uses `tensorflow_datasets` instead of the `Keras` dataset.\n",
    "\n",
    "It's recommended to run this notebook on a GPU. In Google Colab this can be set through the `Runtime -> Change runtime type` menu."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "zYU-k0VsVRbL"
   },
   "source": [
    "# Installation and Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "id": "2dg5WFQApM1_"
   },
   "outputs": [],
   "source": [
    "%pip --quiet install objax"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "id": "oizwGGdIUMr0"
   },
   "outputs": [],
   "source": [
    "import math\n",
    "import random\n",
    "\n",
    "import numpy as np\n",
    "import tensorflow_datasets as tfds\n",
    "\n",
    "import objax\n",
    "from objax.zoo.wide_resnet import WideResNet"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "FJ65n3dQVvll"
   },
   "source": [
    "## Parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "id": "JnzqFRPpVwZ5"
   },
   "outputs": [],
   "source": [
    "base_learning_rate = 0.1 # Learning rate\n",
    "lr_decay_epochs = 30     # How often to decay learning rate\n",
    "lr_decay_factor = 0.2    # By how much to decay learning rate\n",
    "weight_decay =  0.0005   # Weight decay\n",
    "batch_size = 128         # Batch size\n",
    "num_train_epochs = 100   # Number of training epochs\n",
    "wrn_width = 2            # Width of WideResNet\n",
    "wrn_depth = 28           # Depth of WideResNet"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Uu5ildnhWF3M"
   },
   "source": [
    "# Setup dataset and model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "mBQ-xNSDWIsj"
   },
   "outputs": [],
   "source": [
    "# Augmentation function for input data\n",
    "def augment(x):  # x is NCHW\n",
    "  \"\"\"Random flip and random shift augmentation of image batch.\"\"\"\n",
    "  if random.random() < .5:\n",
    "    x = x[:, :, :, ::-1]  # Flip the batch images about the horizontal axis\n",
    "  # Pixel-shift all images in the batch by up to 4 pixels in any direction.\n",
    "  x_pad = np.pad(x, [[0, 0], [0, 0], [4, 4], [4, 4]], 'reflect')\n",
    "  rx, ry = np.random.randint(0, 4), np.random.randint(0, 4)\n",
    "  x = x_pad[:, :, rx:rx + 32, ry:ry + 32]\n",
    "  return x\n",
    "\n",
    "# Data\n",
    "def _to_nchw_unit_range(images):\n",
    "  \"\"\"Reorder image axes with transpose(0, 3, 1, 2) and divide pixel values by 255.\"\"\"\n",
    "  return images.transpose(0, 3, 1, 2) / 255.0\n",
    "\n",
    "data = tfds.as_numpy(tfds.load(name='cifar10', batch_size=-1))\n",
    "x_train, y_train = _to_nchw_unit_range(data['train']['image']), data['train']['label']\n",
    "x_test, y_test = _to_nchw_unit_range(data['test']['image']), data['test']['label']\n",
    "del data  # Drop the reference to the raw dataset dict\n",
    "\n",
    "# Model\n",
    "model = WideResNet(nin=3, nclass=10, depth=wrn_depth, width=wrn_width)\n",
    "# Weight decay is applied only to variables whose name ends with '.w'.\n",
    "weight_decay_vars = [var for name, var in model.vars().items() if name.endswith('.w')]\n",
    "\n",
    "# Optimizer\n",
    "opt = objax.optimizer.Momentum(model.vars(), nesterov=True)\n",
    "\n",
    "# Prediction operation: the model called with training=False, followed by softmax, then jitted\n",
    "predict_op = objax.Jit(\n",
    "    objax.nn.Sequential([objax.ForceArgs(model, training=False), objax.functional.softmax]))\n",
    "\n",
    "# Loss and training op\n",
    "@objax.Function.with_vars(model.vars())\n",
    "def loss_fn(x, label):\n",
    "  \"\"\"Mean sparse cross-entropy of the batch plus the weight decay penalty.\n",
    "\n",
    "  The penalty is weight_decay times the sum of squares of weight_decay_vars.\n",
    "  \"\"\"\n",
    "  logits = model(x, training=True)\n",
    "  per_example_xe = objax.functional.loss.cross_entropy_logits_sparse(logits, label)\n",
    "  squared_sums = ((v ** 2).sum() for v in weight_decay_vars)\n",
    "  return per_example_xe.mean() + weight_decay * sum(squared_sums)\n",
    "\n",
    "# Calling loss_gv(x, label) yields both the gradients and the loss value(s)\n",
    "loss_gv = objax.GradValues(loss_fn, model.vars())\n",
    "\n",
    "@objax.Function.with_vars(model.vars() + loss_gv.vars() + opt.vars())\n",
    "def train_op(x, y, learning_rate):\n",
    "  \"\"\"Apply one optimizer step for batch (x, y) and return the loss value(s) from loss_gv.\"\"\"\n",
    "  gradients, loss_values = loss_gv(x, y)\n",
    "  opt(learning_rate, gradients)\n",
    "  return loss_values\n",
    "\n",
    "# Wrap the training step in objax.Jit\n",
    "train_op = objax.Jit(train_op)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Q6KUOcDxC7gq"
   },
   "source": [
    "**Model parameters**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "uMzCIgeJC-Wu",
    "outputId": "346b5364-0a56-4099-bbb9-78fcba908311"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(WideResNet)[0](Conv2D).w                                        432 (3, 3, 3, 16)\n",
      "(WideResNet)[1](WRNBlock).proj_conv(Conv2D).w                    512 (1, 1, 16, 32)\n",
      "(WideResNet)[1](WRNBlock).norm_1(BatchNorm2D).running_mean        16 (1, 16, 1, 1)\n",
      "(WideResNet)[1](WRNBlock).norm_1(BatchNorm2D).running_var         16 (1, 16, 1, 1)\n",
      "(WideResNet)[1](WRNBlock).norm_1(BatchNorm2D).beta                16 (1, 16, 1, 1)\n",
      "(WideResNet)[1](WRNBlock).norm_1(BatchNorm2D).gamma               16 (1, 16, 1, 1)\n",
      "(WideResNet)[1](WRNBlock).conv_1(Conv2D).w                      4608 (3, 3, 16, 32)\n",
      "(WideResNet)[1](WRNBlock).norm_2(BatchNorm2D).running_mean        32 (1, 32, 1, 1)\n",
      "(WideResNet)[1](WRNBlock).norm_2(BatchNorm2D).running_var         32 (1, 32, 1, 1)\n",
      "(WideResNet)[1](WRNBlock).norm_2(BatchNorm2D).beta                32 (1, 32, 1, 1)\n",
      "(WideResNet)[1](WRNBlock).norm_2(BatchNorm2D).gamma               32 (1, 32, 1, 1)\n",
      "(WideResNet)[1](WRNBlock).conv_2(Conv2D).w                      9216 (3, 3, 32, 32)\n",
      "(WideResNet)[2](WRNBlock).norm_1(BatchNorm2D).running_mean        32 (1, 32, 1, 1)\n",
      "(WideResNet)[2](WRNBlock).norm_1(BatchNorm2D).running_var         32 (1, 32, 1, 1)\n",
      "(WideResNet)[2](WRNBlock).norm_1(BatchNorm2D).beta                32 (1, 32, 1, 1)\n",
      "(WideResNet)[2](WRNBlock).norm_1(BatchNorm2D).gamma               32 (1, 32, 1, 1)\n",
      "(WideResNet)[2](WRNBlock).conv_1(Conv2D).w                      9216 (3, 3, 32, 32)\n",
      "(WideResNet)[2](WRNBlock).norm_2(BatchNorm2D).running_mean        32 (1, 32, 1, 1)\n",
      "(WideResNet)[2](WRNBlock).norm_2(BatchNorm2D).running_var         32 (1, 32, 1, 1)\n",
      "(WideResNet)[2](WRNBlock).norm_2(BatchNorm2D).beta                32 (1, 32, 1, 1)\n",
      "(WideResNet)[2](WRNBlock).norm_2(BatchNorm2D).gamma               32 (1, 32, 1, 1)\n",
      "(WideResNet)[2](WRNBlock).conv_2(Conv2D).w                      9216 (3, 3, 32, 32)\n",
      "(WideResNet)[3](WRNBlock).norm_1(BatchNorm2D).running_mean        32 (1, 32, 1, 1)\n",
      "(WideResNet)[3](WRNBlock).norm_1(BatchNorm2D).running_var         32 (1, 32, 1, 1)\n",
      "(WideResNet)[3](WRNBlock).norm_1(BatchNorm2D).beta                32 (1, 32, 1, 1)\n",
      "(WideResNet)[3](WRNBlock).norm_1(BatchNorm2D).gamma               32 (1, 32, 1, 1)\n",
      "(WideResNet)[3](WRNBlock).conv_1(Conv2D).w                      9216 (3, 3, 32, 32)\n",
      "(WideResNet)[3](WRNBlock).norm_2(BatchNorm2D).running_mean        32 (1, 32, 1, 1)\n",
      "(WideResNet)[3](WRNBlock).norm_2(BatchNorm2D).running_var         32 (1, 32, 1, 1)\n",
      "(WideResNet)[3](WRNBlock).norm_2(BatchNorm2D).beta                32 (1, 32, 1, 1)\n",
      "(WideResNet)[3](WRNBlock).norm_2(BatchNorm2D).gamma               32 (1, 32, 1, 1)\n",
      "(WideResNet)[3](WRNBlock).conv_2(Conv2D).w                      9216 (3, 3, 32, 32)\n",
      "(WideResNet)[4](WRNBlock).norm_1(BatchNorm2D).running_mean        32 (1, 32, 1, 1)\n",
      "(WideResNet)[4](WRNBlock).norm_1(BatchNorm2D).running_var         32 (1, 32, 1, 1)\n",
      "(WideResNet)[4](WRNBlock).norm_1(BatchNorm2D).beta                32 (1, 32, 1, 1)\n",
      "(WideResNet)[4](WRNBlock).norm_1(BatchNorm2D).gamma               32 (1, 32, 1, 1)\n",
      "(WideResNet)[4](WRNBlock).conv_1(Conv2D).w                      9216 (3, 3, 32, 32)\n",
      "(WideResNet)[4](WRNBlock).norm_2(BatchNorm2D).running_mean        32 (1, 32, 1, 1)\n",
      "(WideResNet)[4](WRNBlock).norm_2(BatchNorm2D).running_var         32 (1, 32, 1, 1)\n",
      "(WideResNet)[4](WRNBlock).norm_2(BatchNorm2D).beta                32 (1, 32, 1, 1)\n",
      "(WideResNet)[4](WRNBlock).norm_2(BatchNorm2D).gamma               32 (1, 32, 1, 1)\n",
      "(WideResNet)[4](WRNBlock).conv_2(Conv2D).w                      9216 (3, 3, 32, 32)\n",
      "(WideResNet)[5](WRNBlock).proj_conv(Conv2D).w                   2048 (1, 1, 32, 64)\n",
      "(WideResNet)[5](WRNBlock).norm_1(BatchNorm2D).running_mean        32 (1, 32, 1, 1)\n",
      "(WideResNet)[5](WRNBlock).norm_1(BatchNorm2D).running_var         32 (1, 32, 1, 1)\n",
      "(WideResNet)[5](WRNBlock).norm_1(BatchNorm2D).beta                32 (1, 32, 1, 1)\n",
      "(WideResNet)[5](WRNBlock).norm_1(BatchNorm2D).gamma               32 (1, 32, 1, 1)\n",
      "(WideResNet)[5](WRNBlock).conv_1(Conv2D).w                     18432 (3, 3, 32, 64)\n",
      "(WideResNet)[5](WRNBlock).norm_2(BatchNorm2D).running_mean        64 (1, 64, 1, 1)\n",
      "(WideResNet)[5](WRNBlock).norm_2(BatchNorm2D).running_var         64 (1, 64, 1, 1)\n",
      "(WideResNet)[5](WRNBlock).norm_2(BatchNorm2D).beta                64 (1, 64, 1, 1)\n",
      "(WideResNet)[5](WRNBlock).norm_2(BatchNorm2D).gamma               64 (1, 64, 1, 1)\n",
      "(WideResNet)[5](WRNBlock).conv_2(Conv2D).w                     36864 (3, 3, 64, 64)\n",
      "(WideResNet)[6](WRNBlock).norm_1(BatchNorm2D).running_mean        64 (1, 64, 1, 1)\n",
      "(WideResNet)[6](WRNBlock).norm_1(BatchNorm2D).running_var         64 (1, 64, 1, 1)\n",
      "(WideResNet)[6](WRNBlock).norm_1(BatchNorm2D).beta                64 (1, 64, 1, 1)\n",
      "(WideResNet)[6](WRNBlock).norm_1(BatchNorm2D).gamma               64 (1, 64, 1, 1)\n",
      "(WideResNet)[6](WRNBlock).conv_1(Conv2D).w                     36864 (3, 3, 64, 64)\n",
      "(WideResNet)[6](WRNBlock).norm_2(BatchNorm2D).running_mean        64 (1, 64, 1, 1)\n",
      "(WideResNet)[6](WRNBlock).norm_2(BatchNorm2D).running_var         64 (1, 64, 1, 1)\n",
      "(WideResNet)[6](WRNBlock).norm_2(BatchNorm2D).beta                64 (1, 64, 1, 1)\n",
      "(WideResNet)[6](WRNBlock).norm_2(BatchNorm2D).gamma               64 (1, 64, 1, 1)\n",
      "(WideResNet)[6](WRNBlock).conv_2(Conv2D).w                     36864 (3, 3, 64, 64)\n",
      "(WideResNet)[7](WRNBlock).norm_1(BatchNorm2D).running_mean        64 (1, 64, 1, 1)\n",
      "(WideResNet)[7](WRNBlock).norm_1(BatchNorm2D).running_var         64 (1, 64, 1, 1)\n",
      "(WideResNet)[7](WRNBlock).norm_1(BatchNorm2D).beta                64 (1, 64, 1, 1)\n",
      "(WideResNet)[7](WRNBlock).norm_1(BatchNorm2D).gamma               64 (1, 64, 1, 1)\n",
      "(WideResNet)[7](WRNBlock).conv_1(Conv2D).w                     36864 (3, 3, 64, 64)\n",
      "(WideResNet)[7](WRNBlock).norm_2(BatchNorm2D).running_mean        64 (1, 64, 1, 1)\n",
      "(WideResNet)[7](WRNBlock).norm_2(BatchNorm2D).running_var         64 (1, 64, 1, 1)\n",
      "(WideResNet)[7](WRNBlock).norm_2(BatchNorm2D).beta                64 (1, 64, 1, 1)\n",
      "(WideResNet)[7](WRNBlock).norm_2(BatchNorm2D).gamma               64 (1, 64, 1, 1)\n",
      "(WideResNet)[7](WRNBlock).conv_2(Conv2D).w                     36864 (3, 3, 64, 64)\n",
      "(WideResNet)[8](WRNBlock).norm_1(BatchNorm2D).running_mean        64 (1, 64, 1, 1)\n",
      "(WideResNet)[8](WRNBlock).norm_1(BatchNorm2D).running_var         64 (1, 64, 1, 1)\n",
      "(WideResNet)[8](WRNBlock).norm_1(BatchNorm2D).beta                64 (1, 64, 1, 1)\n",
      "(WideResNet)[8](WRNBlock).norm_1(BatchNorm2D).gamma               64 (1, 64, 1, 1)\n",
      "(WideResNet)[8](WRNBlock).conv_1(Conv2D).w                     36864 (3, 3, 64, 64)\n",
      "(WideResNet)[8](WRNBlock).norm_2(BatchNorm2D).running_mean        64 (1, 64, 1, 1)\n",
      "(WideResNet)[8](WRNBlock).norm_2(BatchNorm2D).running_var         64 (1, 64, 1, 1)\n",
      "(WideResNet)[8](WRNBlock).norm_2(BatchNorm2D).beta                64 (1, 64, 1, 1)\n",
      "(WideResNet)[8](WRNBlock).norm_2(BatchNorm2D).gamma               64 (1, 64, 1, 1)\n",
      "(WideResNet)[8](WRNBlock).conv_2(Conv2D).w                     36864 (3, 3, 64, 64)\n",
      "(WideResNet)[9](WRNBlock).proj_conv(Conv2D).w                   8192 (1, 1, 64, 128)\n",
      "(WideResNet)[9](WRNBlock).norm_1(BatchNorm2D).running_mean        64 (1, 64, 1, 1)\n",
      "(WideResNet)[9](WRNBlock).norm_1(BatchNorm2D).running_var         64 (1, 64, 1, 1)\n",
      "(WideResNet)[9](WRNBlock).norm_1(BatchNorm2D).beta                64 (1, 64, 1, 1)\n",
      "(WideResNet)[9](WRNBlock).norm_1(BatchNorm2D).gamma               64 (1, 64, 1, 1)\n",
      "(WideResNet)[9](WRNBlock).conv_1(Conv2D).w                     73728 (3, 3, 64, 128)\n",
      "(WideResNet)[9](WRNBlock).norm_2(BatchNorm2D).running_mean       128 (1, 128, 1, 1)\n",
      "(WideResNet)[9](WRNBlock).norm_2(BatchNorm2D).running_var        128 (1, 128, 1, 1)\n",
      "(WideResNet)[9](WRNBlock).norm_2(BatchNorm2D).beta               128 (1, 128, 1, 1)\n",
      "(WideResNet)[9](WRNBlock).norm_2(BatchNorm2D).gamma              128 (1, 128, 1, 1)\n",
      "(WideResNet)[9](WRNBlock).conv_2(Conv2D).w                    147456 (3, 3, 128, 128)\n",
      "(WideResNet)[10](WRNBlock).norm_1(BatchNorm2D).running_mean      128 (1, 128, 1, 1)\n",
      "(WideResNet)[10](WRNBlock).norm_1(BatchNorm2D).running_var       128 (1, 128, 1, 1)\n",
      "(WideResNet)[10](WRNBlock).norm_1(BatchNorm2D).beta              128 (1, 128, 1, 1)\n",
      "(WideResNet)[10](WRNBlock).norm_1(BatchNorm2D).gamma             128 (1, 128, 1, 1)\n",
      "(WideResNet)[10](WRNBlock).conv_1(Conv2D).w                   147456 (3, 3, 128, 128)\n",
      "(WideResNet)[10](WRNBlock).norm_2(BatchNorm2D).running_mean      128 (1, 128, 1, 1)\n",
      "(WideResNet)[10](WRNBlock).norm_2(BatchNorm2D).running_var       128 (1, 128, 1, 1)\n",
      "(WideResNet)[10](WRNBlock).norm_2(BatchNorm2D).beta              128 (1, 128, 1, 1)\n",
      "(WideResNet)[10](WRNBlock).norm_2(BatchNorm2D).gamma             128 (1, 128, 1, 1)\n",
      "(WideResNet)[10](WRNBlock).conv_2(Conv2D).w                   147456 (3, 3, 128, 128)\n",
      "(WideResNet)[11](WRNBlock).norm_1(BatchNorm2D).running_mean      128 (1, 128, 1, 1)\n",
      "(WideResNet)[11](WRNBlock).norm_1(BatchNorm2D).running_var       128 (1, 128, 1, 1)\n",
      "(WideResNet)[11](WRNBlock).norm_1(BatchNorm2D).beta              128 (1, 128, 1, 1)\n",
      "(WideResNet)[11](WRNBlock).norm_1(BatchNorm2D).gamma             128 (1, 128, 1, 1)\n",
      "(WideResNet)[11](WRNBlock).conv_1(Conv2D).w                   147456 (3, 3, 128, 128)\n",
      "(WideResNet)[11](WRNBlock).norm_2(BatchNorm2D).running_mean      128 (1, 128, 1, 1)\n",
      "(WideResNet)[11](WRNBlock).norm_2(BatchNorm2D).running_var       128 (1, 128, 1, 1)\n",
      "(WideResNet)[11](WRNBlock).norm_2(BatchNorm2D).beta              128 (1, 128, 1, 1)\n",
      "(WideResNet)[11](WRNBlock).norm_2(BatchNorm2D).gamma             128 (1, 128, 1, 1)\n",
      "(WideResNet)[11](WRNBlock).conv_2(Conv2D).w                   147456 (3, 3, 128, 128)\n",
      "(WideResNet)[12](WRNBlock).norm_1(BatchNorm2D).running_mean      128 (1, 128, 1, 1)\n",
      "(WideResNet)[12](WRNBlock).norm_1(BatchNorm2D).running_var       128 (1, 128, 1, 1)\n",
      "(WideResNet)[12](WRNBlock).norm_1(BatchNorm2D).beta              128 (1, 128, 1, 1)\n",
      "(WideResNet)[12](WRNBlock).norm_1(BatchNorm2D).gamma             128 (1, 128, 1, 1)\n",
      "(WideResNet)[12](WRNBlock).conv_1(Conv2D).w                   147456 (3, 3, 128, 128)\n",
      "(WideResNet)[12](WRNBlock).norm_2(BatchNorm2D).running_mean      128 (1, 128, 1, 1)\n",
      "(WideResNet)[12](WRNBlock).norm_2(BatchNorm2D).running_var       128 (1, 128, 1, 1)\n",
      "(WideResNet)[12](WRNBlock).norm_2(BatchNorm2D).beta              128 (1, 128, 1, 1)\n",
      "(WideResNet)[12](WRNBlock).norm_2(BatchNorm2D).gamma             128 (1, 128, 1, 1)\n",
      "(WideResNet)[12](WRNBlock).conv_2(Conv2D).w                   147456 (3, 3, 128, 128)\n",
      "(WideResNet)[13](BatchNorm2D).running_mean                       128 (1, 128, 1, 1)\n",
      "(WideResNet)[13](BatchNorm2D).running_var                        128 (1, 128, 1, 1)\n",
      "(WideResNet)[13](BatchNorm2D).beta                               128 (1, 128, 1, 1)\n",
      "(WideResNet)[13](BatchNorm2D).gamma                              128 (1, 128, 1, 1)\n",
      "(WideResNet)[16](Linear).b                                        10 (10,)\n",
      "(WideResNet)[16](Linear).w                                      1280 (128, 10)\n",
      "+Total(130)                                                  1471226\n"
     ]
    }
   ],
   "source": [
    "print(model.vars())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "pQH7UyxxWf9j"
   },
   "source": [
    "# Training loop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "KgnrJyv0Whah",
    "outputId": "bcb5a5eb-8c0f-4b42-bf24-fe35a1611097"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch   1 -- train loss 2.183   test accuracy 57.6\n",
      "Epoch   2 -- train loss 1.348   test accuracy 64.7\n",
      "Epoch   3 -- train loss 1.186   test accuracy 64.6\n",
      "Epoch   4 -- train loss 0.975   test accuracy 69.7\n",
      "Epoch   5 -- train loss 1.114   test accuracy 68.5\n",
      "Epoch   6 -- train loss 0.837   test accuracy 68.8\n",
      "Epoch   7 -- train loss 0.871   test accuracy 74.4\n",
      "Epoch   8 -- train loss 0.996   test accuracy 74.8\n",
      "Epoch   9 -- train loss 0.764   test accuracy 79.1\n",
      "Epoch  10 -- train loss 0.813   test accuracy 77.6\n",
      "Epoch  11 -- train loss 0.933   test accuracy 70.4\n",
      "Epoch  12 -- train loss 0.937   test accuracy 65.5\n",
      "Epoch  13 -- train loss 0.814   test accuracy 79.5\n",
      "Epoch  14 -- train loss 0.638   test accuracy 76.2\n",
      "Epoch  15 -- train loss 0.984   test accuracy 77.6\n",
      "Epoch  16 -- train loss 0.844   test accuracy 79.8\n",
      "Epoch  17 -- train loss 0.690   test accuracy 77.9\n",
      "Epoch  18 -- train loss 0.692   test accuracy 77.4\n",
      "Epoch  19 -- train loss 0.706   test accuracy 76.2\n",
      "Epoch  20 -- train loss 0.786   test accuracy 69.9\n",
      "Epoch  21 -- train loss 0.862   test accuracy 79.9\n",
      "Epoch  22 -- train loss 0.691   test accuracy 73.6\n",
      "Epoch  23 -- train loss 1.024   test accuracy 81.1\n",
      "Epoch  24 -- train loss 0.912   test accuracy 79.5\n",
      "Epoch  25 -- train loss 0.704   test accuracy 78.7\n",
      "Epoch  26 -- train loss 0.765   test accuracy 72.9\n",
      "Epoch  27 -- train loss 0.751   test accuracy 78.9\n",
      "Epoch  28 -- train loss 0.693   test accuracy 78.6\n",
      "Epoch  29 -- train loss 0.901   test accuracy 80.5\n",
      "Epoch  30 -- train loss 0.741   test accuracy 78.6\n",
      "Epoch  31 -- train loss 0.588   test accuracy 89.5\n",
      "Epoch  32 -- train loss 0.463   test accuracy 89.7\n",
      "Epoch  33 -- train loss 0.361   test accuracy 88.8\n",
      "Epoch  34 -- train loss 0.519   test accuracy 88.8\n",
      "Epoch  35 -- train loss 0.448   test accuracy 89.5\n",
      "Epoch  36 -- train loss 0.404   test accuracy 88.5\n",
      "Epoch  37 -- train loss 0.327   test accuracy 89.0\n",
      "Epoch  38 -- train loss 0.323   test accuracy 88.8\n",
      "Epoch  39 -- train loss 0.425   test accuracy 86.2\n",
      "Epoch  40 -- train loss 0.354   test accuracy 88.2\n",
      "Epoch  41 -- train loss 0.410   test accuracy 88.7\n",
      "Epoch  42 -- train loss 0.500   test accuracy 88.0\n",
      "Epoch  43 -- train loss 0.438   test accuracy 88.8\n",
      "Epoch  44 -- train loss 0.304   test accuracy 88.9\n",
      "Epoch  45 -- train loss 0.358   test accuracy 87.2\n",
      "Epoch  46 -- train loss 0.390   test accuracy 85.2\n",
      "Epoch  47 -- train loss 0.474   test accuracy 87.6\n",
      "Epoch  48 -- train loss 0.514   test accuracy 88.2\n",
      "Epoch  49 -- train loss 0.411   test accuracy 88.6\n",
      "Epoch  50 -- train loss 0.474   test accuracy 88.4\n",
      "Epoch  51 -- train loss 0.388   test accuracy 89.2\n",
      "Epoch  52 -- train loss 0.598   test accuracy 87.8\n",
      "Epoch  53 -- train loss 0.479   test accuracy 87.9\n",
      "Epoch  54 -- train loss 0.325   test accuracy 88.7\n",
      "Epoch  55 -- train loss 0.562   test accuracy 87.9\n",
      "Epoch  56 -- train loss 0.368   test accuracy 89.3\n",
      "Epoch  57 -- train loss 0.393   test accuracy 89.5\n",
      "Epoch  58 -- train loss 0.444   test accuracy 88.7\n",
      "Epoch  59 -- train loss 0.487   test accuracy 88.4\n",
      "Epoch  60 -- train loss 0.459   test accuracy 86.6\n",
      "Epoch  61 -- train loss 0.279   test accuracy 92.3\n",
      "Epoch  62 -- train loss 0.286   test accuracy 92.2\n",
      "Epoch  63 -- train loss 0.262   test accuracy 92.6\n",
      "Epoch  64 -- train loss 0.273   test accuracy 92.7\n",
      "Epoch  65 -- train loss 0.254   test accuracy 92.6\n",
      "Epoch  66 -- train loss 0.224   test accuracy 92.7\n",
      "Epoch  67 -- train loss 0.300   test accuracy 92.5\n",
      "Epoch  68 -- train loss 0.253   test accuracy 92.4\n",
      "Epoch  69 -- train loss 0.203   test accuracy 92.7\n",
      "Epoch  70 -- train loss 0.243   test accuracy 92.7\n",
      "Epoch  71 -- train loss 0.204   test accuracy 92.5\n",
      "Epoch  72 -- train loss 0.217   test accuracy 92.2\n",
      "Epoch  73 -- train loss 0.198   test accuracy 92.4\n",
      "Epoch  74 -- train loss 0.182   test accuracy 92.6\n",
      "Epoch  75 -- train loss 0.183   test accuracy 92.3\n",
      "Epoch  76 -- train loss 0.170   test accuracy 92.7\n",
      "Epoch  77 -- train loss 0.177   test accuracy 92.3\n",
      "Epoch  78 -- train loss 0.169   test accuracy 92.2\n",
      "Epoch  79 -- train loss 0.180   test accuracy 92.3\n",
      "Epoch  80 -- train loss 0.180   test accuracy 92.3\n",
      "Epoch  81 -- train loss 0.187   test accuracy 92.4\n",
      "Epoch  82 -- train loss 0.225   test accuracy 92.1\n",
      "Epoch  83 -- train loss 0.150   test accuracy 92.2\n",
      "Epoch  84 -- train loss 0.168   test accuracy 92.0\n",
      "Epoch  85 -- train loss 0.141   test accuracy 91.8\n",
      "Epoch  86 -- train loss 0.144   test accuracy 91.8\n",
      "Epoch  87 -- train loss 0.142   test accuracy 92.1\n",
      "Epoch  88 -- train loss 0.141   test accuracy 91.6\n",
      "Epoch  89 -- train loss 0.194   test accuracy 91.6\n",
      "Epoch  90 -- train loss 0.135   test accuracy 91.0\n",
      "Epoch  91 -- train loss 0.131   test accuracy 92.7\n",
      "Epoch  92 -- train loss 0.135   test accuracy 93.0\n",
      "Epoch  93 -- train loss 0.129   test accuracy 93.1\n",
      "Epoch  94 -- train loss 0.133   test accuracy 93.1\n",
      "Epoch  95 -- train loss 0.140   test accuracy 93.2\n",
      "Epoch  96 -- train loss 0.130   test accuracy 93.2\n",
      "Epoch  97 -- train loss 0.128   test accuracy 93.2\n",
      "Epoch  98 -- train loss 0.124   test accuracy 93.1\n",
      "Epoch  99 -- train loss 0.127   test accuracy 93.1\n",
      "Epoch 100 -- train loss 0.124   test accuracy 93.2\n"
     ]
    }
   ],
   "source": [
    "def lr_schedule(epoch):\n",
    "  return base_learning_rate * math.pow(lr_decay_factor, epoch // lr_decay_epochs)\n",
    "\n",
    "num_train_examples = x_train.shape[0]\n",
    "num_test_examples = x_test.shape[0]\n",
    "for epoch in range(num_train_epochs):\n",
    "  # Training\n",
    "  # The learning rate depends only on the epoch, so compute it once per epoch.\n",
    "  learning_rate = lr_schedule(epoch)\n",
    "  example_indices = np.arange(num_train_examples)\n",
    "  np.random.shuffle(example_indices)\n",
    "  batch_losses = []\n",
    "  for idx in range(0, num_train_examples, batch_size):\n",
    "    batch_indices = example_indices[idx:idx + batch_size]\n",
    "    x = x_train[batch_indices]\n",
    "    y = y_train[batch_indices]\n",
    "    # train_op returns the loss value(s) from loss_gv; keep the first one.\n",
    "    batch_losses.append(train_op(augment(x), y, learning_rate)[0])\n",
    "  # Report the mean over all batches of the epoch instead of only the last\n",
    "  # (possibly partial) batch, which is a much noisier number.\n",
    "  loss = float(np.mean(batch_losses))\n",
    "\n",
    "  # Eval\n",
    "  num_correct = 0\n",
    "  for idx in range(0, num_test_examples, batch_size):\n",
    "    x = x_test[idx:idx + batch_size]\n",
    "    y = y_test[idx:idx + batch_size]\n",
    "    p = predict_op(x)\n",
    "    # The predicted class is the index of the largest softmax output.\n",
    "    num_correct += (np.argmax(p, axis=1) == y).sum()\n",
    "  accuracy = num_correct / num_test_examples\n",
    "  print(f'Epoch {epoch+1:3} -- train loss {loss:.3f}   test accuracy {accuracy*100:.1f}', flush=True)\n"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "name": "ObJAX_CIFAR10_example.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
